#!/usr/bin/python3
# -*- coding: utf-8 -*-
import json
import os
from datasets import load_dataset, Dataset, DatasetDict
import pandas as pd
import logging
import random
import pdb

logger = logging.getLogger(__name__)


class ABC:
    """Base wrapper that loads a dataset and exposes list-like access to it.

    Subclasses configure the wrapper through the class attributes below. The
    base class handles loading (from a local JSON file or via ``load_dataset``),
    optional split selection and optional slicing through
    ``load_partial_dataset``.
    """

    name = "base"  # wrapper name, printed when loading through load_dataset
    ice_separator = None  # not used in this class; presumably set by subclasses
    question_field = None  # not used in this class; presumably set by subclasses
    answer_field = None  # not used in this class; presumably set by subclasses
    hf_dataset = None  # first positional argument passed to load_dataset
    hf_dataset_name = None  # second positional argument passed to load_dataset
    field_getter = None  # mapping field name -> callable(entry); must be set before get_field/get_corpus
    a_prefix = ""  # an answer prefix as beginning word, e.g., SELECT

    def __init__(self, dataset_path=None, dataset_split=None, ds_size=None, ds_segment=None, lang_name=None, high_lang_name=None, all_data=None, test_lang_name=None):
        """Load the dataset and optionally narrow it to a split and/or a slice.

        Args:
            dataset_path: path to a JSON file readable by ``pd.read_json``. It is
                used only when it is not None and exists on disk; otherwise the
                dataset is fetched with ``load_dataset``.
            dataset_split: split to keep; applied only when the loaded object is
                a ``DatasetDict``.
            ds_size: when not None, forwarded to ``load_partial_dataset`` as ``size``.
            ds_segment: forwarded to ``load_partial_dataset`` as ``segment``.
            lang_name: extra keyword forwarded to ``load_dataset`` on the remote
                path (presumably consumed by the dataset's loading script — confirm).
            high_lang_name: extra keyword forwarded to ``load_dataset`` (same caveat).
            all_data: accepted for interface compatibility; unused here.
            test_lang_name: accepted for interface compatibility; unused here.
        """
        have_local_file = dataset_path is not None and os.path.exists(dataset_path)

        if have_local_file:
            frame = pd.read_json(dataset_path)
            self.dataset = Dataset.from_pandas(frame)
            logger.info(f"Loading dataset from {dataset_path}, size {len(self.dataset)}")
        else:
            print('Initializing dataset wrapper: {}'.format(self.name))
            self.dataset = load_dataset(
                self.hf_dataset,
                self.hf_dataset_name,
                lang_name=lang_name,
                high_lang_name=high_lang_name,
                trust_remote_code=True,
            )

        # Split selection is applied only to a DatasetDict; any other loaded
        # object is left untouched.
        wants_split = dataset_split is not None
        if wants_split and isinstance(self.dataset, DatasetDict):
            self.dataset = self.dataset[dataset_split]

        if ds_size is None:
            return
        self.dataset = load_partial_dataset(self.dataset, size=ds_size, segment=ds_segment)

    def __getitem__(self, idx):
        """Delegate indexing to the underlying dataset."""
        dataset = self.dataset
        return dataset[idx]

    def __len__(self):
        """Return the length of the underlying dataset."""
        dataset = self.dataset
        return len(dataset)

    def get_field(self, entry, field):
        """Extract ``field`` from ``entry`` with the getter registered in ``field_getter``."""
        extractor = self.field_getter[field]
        return extractor(entry)

    def get_corpus(self, field):
        """Return ``field`` extracted from every entry of the dataset, in iteration order."""
        corpus = []
        for row in self.dataset:
            corpus.append(self.get_field(row, field))
        return corpus


def load_partial_dataset(dataset, size=1, segment=None):
    """
    Load a partial dataset based on the specified size and segment.

    - size: number of examples to return. A value strictly between 0 and 1 is
      treated as a fraction of the dataset (e.g. 0.1 -> 10% of the rows,
      rounded down, but at least one row). Whole-number floats such as 100.0
      are accepted and converted to int.
    - segment: 0-based index selecting the segment-th chunk of 'size' items
      from the dataset. None selects the first chunk.

    The dataset is returned unchanged when ``size <= 0`` or when ``size`` is at
    least the dataset length. A segment that starts past the end of the dataset
    yields an empty selection; the last segment may be shorter than ``size``.

    Example:
    size = 100
    segment = 3
    will return items from index 300 to 399 (100 * 3 to 100 * (3 + 1))

    Raises:
        ValueError: if ``segment`` is negative.
    """
    total_size = len(dataset)

    # If full dataset requested or size is invalid
    if size <= 0 or size >= total_size:
        return dataset

    if size < 1:
        # Fractional size: interpret as a share of the dataset. Without this,
        # a float bound would reach range() below and raise TypeError.
        size = max(1, int(size * total_size))
    else:
        # range() rejects floats, so normalise e.g. 100.0 -> 100.
        size = int(size)

    start_idx = 0
    if segment is not None:
        if segment < 0:
            raise ValueError(f"segment must be non-negative, got {segment}")
        start_idx = size * segment

    # Clamp so the final (possibly partial) segment stops at the dataset end;
    # when start_idx >= total_size the range is empty.
    end_idx = min(start_idx + size, total_size)

    return dataset.select(range(start_idx, end_idx))


# def load_partial_dataset(dataset, size=1):
#     if size == 1 or size >= len(dataset):
#         return dataset

#     total_size = len(dataset)
#     size = int(size * total_size) if size < 1 else size

#     rand = random.Random(x=31) # make it random int 
#     index_list = list(range(total_size))
#     rand.shuffle(index_list)
#     dataset = dataset.select(index_list[:size])
#     return dataset